{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Face detection for video\n",
    "Detected face images are saved as `frameXfaceY.jpg`, where `X` is the number of the frame in the video (counted from 1) and `Y` is the index of the face within that frame (counted from 0)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from keras import backend as K\n",
    "from pathlib import PurePath, Path\n",
    "from moviepy.editor import VideoFileClip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from umeyama import umeyama\n",
    "import mtcnn_detect_face"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create MTCNN and its forward pass functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_mtcnn(sess, model_path):\n",
    "    \"\"\"\n",
    "    Build the three MTCNN stages (PNet, RNet, ONet) and load their weights.\n",
    "\n",
    "    sess: TensorFlow session handed to each network's `load` call\n",
    "    model_path: directory holding det1.npy, det2.npy and det3.npy. When falsy,\n",
    "                the directory of this file is used, or the current working\n",
    "                directory if `__file__` is not defined (e.g. in a notebook).\n",
    "\n",
    "    Returns the tuple (pnet, rnet, onet) of mtcnn_detect_face network objects.\n",
    "    \"\"\"\n",
    "    if not model_path:\n",
    "        try:\n",
    "            model_path, _ = os.path.split(os.path.realpath(__file__))\n",
    "        except NameError:\n",
    "            # No __file__ in this namespace: fall back to the working directory.\n",
    "            model_path = os.getcwd()\n",
    "\n",
    "    # (variable scope, network class, fixed input height/width, weight file)\n",
    "    stages = (\n",
    "        ('pnet2', mtcnn_detect_face.PNet, None, 'det1.npy'),\n",
    "        ('rnet2', mtcnn_detect_face.RNet, 24, 'det2.npy'),\n",
    "        ('onet2', mtcnn_detect_face.ONet, 48, 'det3.npy'),\n",
    "    )\n",
    "    networks = []\n",
    "    for scope, net_cls, side, weights_file in stages:\n",
    "        with tf.variable_scope(scope):\n",
    "            data = tf.placeholder(tf.float32, (None, side, side, 3), 'input')\n",
    "            net = net_cls({'data': data})\n",
    "            net.load(os.path.join(model_path, weights_file), sess)\n",
    "        networks.append(net)\n",
    "    pnet, rnet, onet = networks\n",
    "    return pnet, rnet, onet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "WEIGHTS_PATH = \"./mtcnn_weights/\"\n",
    "\n",
    "# Build the MTCNN graphs while the Keras session is the default session.\n",
    "# No `global` statements here: at module level they have no effect, and a\n",
    "# `global` that follows an assignment to the same name is a SyntaxError.\n",
    "sess = K.get_session()\n",
    "with sess.as_default():\n",
    "    pnet_model, rnet_model, onet_model = create_mtcnn(sess, WEIGHTS_PATH)\n",
    "\n",
    "# Wrap every stage as a Keras function from its 'data' input to the listed\n",
    "# output layers. process_video reads these as the globals pnet, rnet, onet.\n",
    "pnet = K.function([pnet_model.layers['data']],\n",
    "                  [pnet_model.layers['conv4-2'], pnet_model.layers['prob1']])\n",
    "rnet = K.function([rnet_model.layers['data']],\n",
    "                  [rnet_model.layers['conv5-2'], rnet_model.layers['prob1']])\n",
    "onet = K.function([onet_model.layers['data']],\n",
    "                  [onet_model.layers['conv6-2'], onet_model.layers['conv6-3'], onet_model.layers['prob1']])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the folders where the images will be saved"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One folder per kind of saved image: aligned crops, raw crops, eye masks.\n",
    "for out_dir in (\"faces/aligned_faces\", \"faces/raw_faces\", \"faces/binary_masks_eyes\"):\n",
    "    Path(out_dir).mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Functions for video processing and face alignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_src_landmarks(x0, x1, y0, y1, pnts, face_idx=0):\n",
    "    \"\"\"\n",
    "    Express the landmarks of one face relative to the origin of its bbox.\n",
    "\n",
    "    x0, x1, y0, y1: (smoothed) bbox coord. Only x0 and y0 are used.\n",
    "    pnts: landmarks predicted by MTCNN. Rows i+5 and i (i = 0..4) are read as\n",
    "          the two coordinates of landmark i; the second index selects the face\n",
    "          (assumed one column per detected face -- confirm against\n",
    "          mtcnn_detect_face.detect_face).\n",
    "    face_idx: second index into pnts. Defaults to 0, the value that used to be\n",
    "              hard-coded, so existing callers are unaffected.\n",
    "\n",
    "    Returns a list of five (int, int) tuples\n",
    "    (pnts[i+5][face_idx] - x0, pnts[i][face_idx] - y0).\n",
    "    \"\"\"\n",
    "    src_landmarks = [(int(pnts[i+5][face_idx]-x0),\n",
    "                      int(pnts[i][face_idx]-y0)) for i in range(5)]\n",
    "    return src_landmarks\n",
    "\n",
    "def get_tar_landmarks(img):\n",
    "    \"\"\"\n",
    "    Scale five fixed landmark ratios to the size of the given image.\n",
    "\n",
    "    img: detected face image\n",
    "\n",
    "    Returns a list of five (int, int) tuples: the first ratio of each pair\n",
    "    times img.shape[0], the second times img.shape[1], both truncated to int.\n",
    "    \"\"\"\n",
    "    size_0, size_1 = img.shape[0], img.shape[1]\n",
    "    ratio_landmarks = (\n",
    "        (0.31339227236234224, 0.3259269274198092),\n",
    "        (0.31075140146108776, 0.7228453709528997),\n",
    "        (0.5523683107816256, 0.5187296867370605),\n",
    "        (0.7752419985257663, 0.37262483743520886),\n",
    "        (0.7759613623985877, 0.6772957581740159),\n",
    "    )\n",
    "\n",
    "    tar_landmarks = []\n",
    "    for ratio_0, ratio_1 in ratio_landmarks:\n",
    "        tar_landmarks.append((int(ratio_0 * size_0), int(ratio_1 * size_1)))\n",
    "    return tar_landmarks\n",
    "\n",
    "def landmarks_match_mtcnn(src_im, src_landmarks, tar_landmarks):\n",
    "    \"\"\"\n",
    "    Warp src_im with the transform umeyama fits from src_landmarks to tar_landmarks.\n",
    "\n",
    "    umeyama(src, dst, estimate_scale)\n",
    "    landmarks coord. for umeyama should be (width, height) or (y, x), so the\n",
    "    two entries of every landmark are swapped (and cast to int) first.\n",
    "    The result has the same size as src_im; borders are replicated.\n",
    "    \"\"\"\n",
    "    def _swapped(points):\n",
    "        # (a, b) -> (b, a) for every landmark, as an integer array\n",
    "        return np.array([(int(pt[1]), int(pt[0])) for pt in points])\n",
    "\n",
    "    height, width = src_im.shape[0], src_im.shape[1]\n",
    "    transform = umeyama(_swapped(src_landmarks), _swapped(tar_landmarks), True)\n",
    "    # warpAffine takes the first two rows of the transform and dsize as (width, height)\n",
    "    M = transform[0:2]\n",
    "    return cv2.warpAffine(src_im, M, (width, height), borderMode=cv2.BORDER_REPLICATE)\n",
    "\n",
    "def process_mtcnn_bbox(bboxes, im_shape):\n",
    "    \"\"\"\n",
    "    Replace every MTCNN bbox by a square bbox clipped to the image, in place.\n",
    "\n",
    "    output bbox coordinate of MTCNN is (y0, x0, y1, x1)\n",
    "    On return the first four entries of each row hold (x0, y1, x1, y0) of the\n",
    "    square box; x is clipped to [0, im_shape[0]] and y to [0, im_shape[1]].\n",
    "    Any further columns are left untouched. The input array itself is returned.\n",
    "    \"\"\"\n",
    "    for row in range(len(bboxes)):\n",
    "        y0, x0, y1, x1 = bboxes[row, 0:4]\n",
    "        # side of the square: mean of the two truncated extents of the box\n",
    "        side = (int(y1 - y0) + int(x1 - x0))/2\n",
    "        half = side//2\n",
    "        center_x = int((x1+x0)/2)\n",
    "        center_y = int((y1+y0)/2)\n",
    "        bboxes[row, 0:4] = (\n",
    "            np.max([0, center_x - half]),\n",
    "            np.min([im_shape[1], center_y + half]),\n",
    "            np.min([im_shape[0], center_x + half]),\n",
    "            np.max([0, center_y - half]),\n",
    "        )\n",
    "    return bboxes\n",
    "\n",
    "def process_video(input_img):\n",
    "    \"\"\"\n",
    "    Per-frame callback: detect faces on every save_interval-th frame and save them.\n",
    "\n",
    "    input_img: one video frame as an image array\n",
    "\n",
    "    Reads the globals save_interval, pnet, rnet, onet and increments the global\n",
    "    frame counter `frames` on every call. For each face found on a processed\n",
    "    frame three JPEGs named frame{frames}face{idx}.jpg are written:\n",
    "    the aligned crop (faces/aligned_faces), the raw crop (faces/raw_faces) and\n",
    "    an aligned binary mask around the first two landmarks (faces/binary_masks_eyes).\n",
    "\n",
    "    Always returns a dummy 3x3x3 zero array; the frame itself is not modified.\n",
    "    \"\"\"\n",
    "    global frames, save_interval\n",
    "    global pnet, rnet, onet\n",
    "    minsize = 30 # minimum size of face\n",
    "    detec_threshold = 0.7\n",
    "    threshold = [0.6, 0.7, detec_threshold]  # three steps's threshold\n",
    "    factor = 0.709 # scale factor\n",
    "\n",
    "    frames += 1\n",
    "    if frames % save_interval == 0:\n",
    "        faces, pnts = mtcnn_detect_face.detect_face(\n",
    "            input_img, minsize, pnet, rnet, onet, threshold, factor)\n",
    "        faces = process_mtcnn_bbox(faces, input_img.shape)\n",
    "\n",
    "        for idx, (x0, y1, x1, y0, conf_score) in enumerate(faces):\n",
    "            det_face_im = input_img[int(x0):int(x1),int(y0):int(y1),:]\n",
    "            if det_face_im.size == 0:\n",
    "                # degenerate bbox: there is no crop to align or save\n",
    "                continue\n",
    "\n",
    "            # get src/tar landmarks\n",
    "            # Keep only this face's landmark column; get_src_landmarks reads\n",
    "            # column 0, which previously meant face 0 for every detected face.\n",
    "            # (assumes pnts has one column per face -- confirm against detect_face)\n",
    "            face_pnts = np.asarray(pnts)[:, idx:idx+1]\n",
    "            src_landmarks = get_src_landmarks(x0, x1, y0, y1, face_pnts)\n",
    "            tar_landmarks = get_tar_landmarks(det_face_im)\n",
    "\n",
    "            # align detected face\n",
    "            aligned_det_face_im = landmarks_match_mtcnn(\n",
    "                det_face_im, src_landmarks, tar_landmarks)\n",
    "\n",
    "            fname = f\"./faces/aligned_faces/frame{frames}face{str(idx)}.jpg\"\n",
    "            plt.imsave(fname, aligned_det_face_im, format=\"jpg\")\n",
    "            fname = f\"./faces/raw_faces/frame{frames}face{str(idx)}.jpg\"\n",
    "            plt.imsave(fname, det_face_im, format=\"jpg\")\n",
    "\n",
    "            # binary mask: a box of +-h/15 by +-w/8 around each of the first two\n",
    "            # landmarks (presumably the eyes, given the output folder name)\n",
    "            bm = np.zeros_like(aligned_det_face_im)\n",
    "            h, w = bm.shape[:2]\n",
    "            for lm_0, lm_1 in src_landmarks[:2]:\n",
    "                # clamp at 0: a negative slice bound would wrap around the array\n",
    "                top, bottom = max(0, int(lm_0-h/15)), max(0, int(lm_0+h/15))\n",
    "                left, right = max(0, int(lm_1-w/8)), max(0, int(lm_1+w/8))\n",
    "                bm[top:bottom, left:right, :] = 255\n",
    "            bm = landmarks_match_mtcnn(bm, src_landmarks, tar_landmarks)\n",
    "            fname = f\"./faces/binary_masks_eyes/frame{frames}face{str(idx)}.jpg\"\n",
    "            plt.imsave(fname, bm, format=\"jpg\")\n",
    "\n",
    "    return np.zeros((3,3,3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Start face detection\n",
    "\n",
    "Default input video filename: `INPUT_VIDEO.mp4`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frame counter; process_video increments it for every frame it receives.\n",
    "frames = 0\n",
    "\n",
    "# configuration\n",
    "save_interval = 6 # perform face detection every {save_interval} frames\n",
    "fn_input_video = \"INPUT_VIDEO.mp4\"\n",
    "\n",
    "# process_video returns a dummy frame, so this output file is only a\n",
    "# by-product of running the callback over the whole video.\n",
    "output = 'dummy.mp4'\n",
    "clip1 = VideoFileClip(fn_input_video)\n",
    "try:\n",
    "    clip = clip1.fl_image(process_video)#.subclip(0,3) #NOTE: this function expects color images!!\n",
    "    clip.write_videofile(output, audio=False)\n",
    "finally:\n",
    "    # close the reader even when processing or writing fails\n",
    "    clip1.reader.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Raw and aligned face images are saved in `faces/raw_faces` and `faces/aligned_faces` respectively. Binary eye masks are saved in `faces/binary_masks_eyes`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
